Accelerating Inference in Molecular Diffusion Models with Latent Representations of Protein Structure

Diffusion generative models have emerged as a powerful framework for addressing problems in structural biology and structure-based drug design. These models operate directly on 3D molecular structures. Due to the unfavorable scaling of graph neural networks (GNNs) with graph size, as well as the relatively slow inference speeds inherent to diffusion models, many existing molecular diffusion models rely on coarse-grained representations of protein structure to make training and inference feasible. However, such coarse-grained representations discard information essential for modeling molecular interactions and impair the quality of generated structures. In this work, we present a novel GNN-based architecture for learning latent representations of molecular structure. When trained end-to-end with a diffusion model for de novo ligand design, our model achieves performance comparable to a model with an all-atom protein representation while exhibiting a 3-fold reduction in inference time.


Introduction
There has been a surge of interest in leveraging diffusion models to address problems in structure-based drug design. These efforts have yielded promising outcomes, exemplified by successes in de novo ligand design [1], molecular docking [2], fragment linker design [3], and scaffold hopping [4]. These models apply diffusion processes to point cloud representations of protein/ligand complexes and employ geometric graph neural networks (GNNs) to make denoising predictions. However, the memory and compute requirements of GNNs scale unfavorably with graph size, and this scaling issue poses a particular challenge for diffusion models, which rely on many forward passes to generate a single sample. Some of these molecular diffusion models use coarse-grained representations of their molecular systems to make training and inference computationally feasible [2,4]. While [1] and [3] train models with both coarse-grained and all-atom protein representations, their results show superior performance with all-atom representations, at the cost of more expensive and time-consuming training and inference. This is likely because residue-level representations discard precise information about side-chain orientation, which is critical for modeling binding events [5][6][7].
In developing molecular diffusion models and applying them at scale, researchers must grapple with the trade-off between computational cost and performance imposed by their choice of molecular representation. This work proposes a new molecular representation that combines the expressiveness of all-atom representations with the computational efficiency of coarse-grained representations. In summary, our main contributions are:
1. A novel GNN-based architecture for learning condensed representations of molecular structure, allowing end-to-end training for downstream tasks that operate on these latent geometric representations.
2. A diffusion model for de novo ligand design that achieves a 3-fold increase in inference speed by conditioning ligand generation on a learned representation of protein structure.

Background
Denoising diffusion probabilistic models. Diffusion models [8,9] define a forward diffusion process consisting of T noising steps that convert samples from a data distribution at step t = 0 into samples from a prior distribution at step t = T through repeated addition of random noise. The forward diffusion process conditioned on an initial data point $x_0$ can be defined by Equation 1:

$$q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \alpha_t x_0,\ \sigma_t^2 I\right) \qquad (1)$$

where $\alpha_t, \sigma_t \in \mathbb{R}^+$ are functions that control the amount of signal retained from $x_0$ and the amount of noise added to it, respectively. In this work, $\alpha_t$ transitions smoothly from $\alpha_0 \approx 1$ to $\alpha_T \approx 0$. We specifically work with variance-preserving diffusion processes, for which $\alpha_t = \sqrt{1 - \sigma_t^2}$. Equation 1 can be equivalently written as:

$$x_t = \alpha_t x_0 + \sigma_t \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \qquad (2)$$

A neural network trained to predict $\epsilon$ from noisy data points $x_t$ can be used to parameterize a reverse diffusion process $p_\theta(x_{t-1} \mid x_t)$ that converts samples from the prior distribution into samples from the training data distribution. We refer to this neural network as the "noise prediction network" $\hat{\epsilon}_\theta(x_t, t)$. For molecules, the noisy input $z_t$ comprises both atom positions and atom features, and the noise prediction network $\hat{\epsilon}_\theta(z_t, t)$ outputs an E(3)-equivariant vector $\hat{\epsilon}^{(x)}$ and an E(3)-invariant vector $\hat{\epsilon}^{(s)}$ for each node, representing the noise to be removed from atom positions and atom features, respectively.
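To make the variance-preserving process concrete, the sketch below samples $x_t = \alpha_t x_0 + \sigma_t \epsilon$ with $\sigma_t = \sqrt{1 - \alpha_t^2}$ and shows why predicting $\epsilon$ is sufficient: given $\hat{\epsilon}$, an estimate of $x_0$ follows by rearranging the same equation. The function names `vp_forward` and `estimate_x0` are hypothetical, and the scalar `alpha_t` stands in for whatever value the noise schedule assigns to step t.

```python
import numpy as np

def vp_forward(x0, alpha_t, rng):
    """Sample x_t ~ q(x_t | x_0) for a variance-preserving process."""
    sigma_t = np.sqrt(1.0 - alpha_t**2)   # alpha_t^2 + sigma_t^2 = 1
    eps = rng.standard_normal(x0.shape)   # eps ~ N(0, I)
    return alpha_t * x0 + sigma_t * eps, eps

def estimate_x0(x_t, eps_hat, alpha_t):
    """Invert x_t = alpha_t * x_0 + sigma_t * eps using a noise prediction."""
    sigma_t = np.sqrt(1.0 - alpha_t**2)
    return (x_t - sigma_t * eps_hat) / alpha_t
```

If $x_0$ has unit variance, $x_t$ keeps unit variance at every step, which is what "variance-preserving" refers to.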

Equivariant diffusion on molecules
Diffusion for protein-ligand complexes. Schneuing et al. [1] introduce DiffSBDD, a conditional EDM for generating small molecules inside a protein binding pocket. Both the ligand and the protein binding pocket are represented as point clouds, $z^{(L)}$ and $z^{(P)}$, respectively. $z^{(L)}$ is an all-atom point cloud with one node per atom, while $z^{(P)}$ is either an all-atom point cloud or a Cα point cloud containing one node for every residue in the binding pocket, located at the residue's alpha carbon.
Schneuing et al. [1] propose two distinct diffusion processes for pocket-conditioned generation.
The first is a conditional diffusion model in which the diffusion process is defined only for $z^{(L)}$: the noise prediction network takes the noisy ligand $z^{(L)}_t$ as input, and the pocket structure $z^{(P)}$ remains unchanged throughout the denoising process, $\hat{\epsilon}_\theta(z^{(L)}_t, z^{(P)}, t)$. The second method defines a joint diffusion process on both $z^{(L)}$ and $z^{(P)}$: $\hat{\epsilon}_\theta$ is trained to denoise both the ligand and the pocket at every timestep, $\hat{\epsilon}_\theta(z^{(L)}_t, z^{(P)}_t, t)$, and an inpainting procedure is used to generate ligands inside a given pocket. In both cases, $z^{(L)}$ and $z^{(P)}$ are passed to $\hat{\epsilon}_\theta$ as a heterogeneous graph in which nodes are atoms or residues and edges are created based on the Euclidean distance between nodes.
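Distance-based edge construction of the kind described above can be sketched as a radius graph: every pair of nodes closer than a cutoff receives a directed edge. The cutoff value and the helper name `radius_edges` are hypothetical; the text does not state the cutoff used in [1].

```python
import numpy as np

def radius_edges(src_pos, dst_pos, cutoff, exclude_self=False):
    """Directed edges (source -> destination) between points closer than `cutoff`.

    src_pos: (N, 3), dst_pos: (M, 3). Returns an (E, 2) int array of
    (source index, destination index) pairs.
    """
    # Pairwise Euclidean distances, shape (N, M).
    d = np.linalg.norm(src_pos[:, None, :] - dst_pos[None, :, :], axis=-1)
    mask = d < cutoff
    if exclude_self:
        # Drop self-loops; only meaningful when src and dst are the same cloud.
        np.fill_diagonal(mask, False)
    src, dst = np.nonzero(mask)
    return np.stack([src, dst], axis=1)
```

Calling it once per ordered pair of node types (pocket to ligand, ligand to ligand, and so on) yields the edge sets of the heterogeneous graph.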

Method
To train a conditional EDM for pocket-conditioned ligand generation as described in [1], the noise prediction network $\hat{\epsilon}_\theta$ must have access to some representation of the protein binding pocket. Taking inspiration from Ganea et al. [11], we propose an encoder $E_\theta(z^{(P)})$ that accepts an all-atom protein point cloud as input and returns a small, fixed-size point cloud $z^{(KP)}$, which we term the "keypoint representation". The receptor encoder and diffusion model can be trained end-to-end by minimizing the denoising loss in Equation 3:

$$\mathcal{L} = \mathbb{E}_{t,\,\epsilon}\left[\left\lVert \epsilon - \hat{\epsilon}_\theta\left(z^{(L)}_t,\ E_\theta(z^{(P)}),\ t\right)\right\rVert^2\right] \qquad (3)$$
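A single Monte Carlo sample of the denoising objective (Equation 3) can be sketched as follows. Because the encoder output feeds the noise prediction, gradients of this loss reach both networks in an autodiff framework, which is what makes end-to-end training possible. Here `encoder`, `noise_net`, and `alpha_fn` are hypothetical callables standing in for $E_\theta$, $\hat{\epsilon}_\theta$, and the noise schedule; only ligand positions are noised, for brevity.

```python
import numpy as np

def denoising_loss(x0_lig, pocket, encoder, noise_net, alpha_fn, T, rng):
    """One-sample estimate of the conditional denoising loss.

    encoder(pocket) -> keypoints; noise_net(x_t, keypoints, t) -> predicted noise.
    All three callables are hypothetical stand-ins.
    """
    t = rng.integers(1, T + 1)              # timestep drawn uniformly from {1, ..., T}
    alpha = alpha_fn(t, T)
    sigma = np.sqrt(1.0 - alpha**2)         # variance-preserving process
    eps = rng.standard_normal(x0_lig.shape)
    x_t = alpha * x0_lig + sigma * eps      # noisy ligand; the pocket is not noised
    keypoints = encoder(pocket)             # learned pocket representation
    eps_hat = noise_net(x_t, keypoints, t)
    # Mean over elements: a normalized version of the squared norm.
    return np.mean((eps - eps_hat) ** 2)
```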
DiffSBDD [1] and EDMs [10] parameterize $\hat{\epsilon}_\theta$ using the geometric GNN architecture known as EGNN [12]. Within the EGNN architecture, nodes possess a single vector feature that, in practice, is designated as the node's position in space. As a result, there is no point in the EGNN architecture at which a node retains geometric information describing its local environment. We intuit that EGNN-based architectures may perform poorly on structure representations in which a node cannot be adequately described by a single point mass, i.e., residue or fragment point clouds. Specifically, we hypothesize that EGNN may struggle to learn representations of protein structure that are both informative and condensed. To investigate this, we train all models with both EGNN-based and Geometric Vector Perceptron (GVP) [13,14] based architectures. GVP-GNN can be seen as a generalization of EGNN to the setting where nodes can have an arbitrary number of vector features [15].

Pocket encoder module
The pocket encoder module is designed to take an all-atom point cloud of the protein binding pocket $z^{(P)}$ as input and produce a point cloud $z^{(KP)} = E_\theta(z^{(P)})$ with K nodes as output. K is a hyperparameter of the model chosen to be significantly smaller than the number of atoms in a binding pocket; in our training dataset, binding pockets have on the order of hundreds of nodes. We present results for models with K = 40, which is close to the average number of residues in a binding pocket. The nodes of $z^{(KP)}$, referred to as keypoints, have positions in space $x_i \in \mathbb{R}^3$ as well as scalar features $s_i \in \mathbb{R}^d$. When the pocket encoder module is parameterized with GVP-GNN, each keypoint is also endowed with vector features $v_i \in \mathbb{R}^{c \times 3}$.

The sequence of operations within the pocket encoder module is summarized in Figure 1. First, message passing is performed along edges between binding pocket atoms. Keypoint nodes are then added to the graph without positions or features, and edges are drawn from receptor nodes to keypoint nodes to form a unidirectional complete bipartite graph. Keypoint positions are obtained via a dot-product variant of graph attention [16] along these pocket-keypoint edges. Keypoint position assignment is followed by a "graph rewiring" step that selectively removes pocket-keypoint edges so that each keypoint only has incoming edges from the nearest pocket atoms. Finally, message passing along these local pocket-keypoint edges endows keypoint nodes with spatially localized features. Additional architectural details, including equations for graph convolutions and keypoint placement, are provided in Appendix B.
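The graph rewiring step can be sketched as pruning the complete bipartite pocket-to-keypoint graph down to each keypoint's nearest pocket atoms. The text does not say how "nearest" is defined, so this sketch assumes a fixed number `k` of nearest neighbors per keypoint (a distance cutoff would be an equally plausible choice); `rewire_to_nearest` and `k` are hypothetical.

```python
import numpy as np

def rewire_to_nearest(pocket_pos, keypoint_pos, k):
    """Keep only P -> KP edges from each keypoint's k nearest pocket atoms.

    pocket_pos: (N, 3), keypoint_pos: (K, 3).
    Returns a (K*k, 2) array of (pocket index, keypoint index) edges.
    """
    # Keypoint-to-pocket distances, shape (K, N).
    d = np.linalg.norm(keypoint_pos[:, None, :] - pocket_pos[None, :, :], axis=-1)
    nearest = np.argsort(d, axis=1)[:, :k]                  # k closest pocket atoms per keypoint
    kp_idx = np.repeat(np.arange(len(keypoint_pos)), k)     # destination keypoint of each edge
    return np.stack([nearest.ravel(), kp_idx], axis=1)
```

After rewiring, each keypoint aggregates messages only from its spatial neighborhood, which is what gives keypoint features their localized character.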
Optimal transport loss. We find that enforcing spatial alignment between keypoint positions and the true protein/ligand interface is a useful inductive bias. For each protein/ligand pair in the training set, we compute a set of interface points $x^{(IP)} \in \mathbb{R}^{S \times 3}$, defined as the midpoints between all pairs of ligand atoms and binding pocket atoms that are less than 5 Å apart. We apply an optimal transport loss that is minimized when keypoint positions align with the true protein/ligand interface:

$$\mathcal{L}_{OT} = \min_{T \in \mathcal{U}(S, K)} \langle T, C \rangle \qquad (4)$$

where $\mathcal{U}(S, K)$ is the set of transport plans with uniform marginals, $C \in \mathbb{R}^{S \times K}$ is the cost matrix between interface points and keypoint positions, and $\langle T, C \rangle$ is the Frobenius inner product between the transport plan T and the cost matrix C. The optimal transport plan is computed in the forward pass using the Python Optimal Transport (POT) package [17] and is held fixed during the backward pass.
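The loss can be sketched in NumPy as follows. To keep the example dependency-free, it swaps the exact solver from POT for entropic Sinkhorn iterations, which only approximate the exact plan. The squared-Euclidean cost, the regularization `reg`, and the helper names are assumptions. The plan `T` is treated as a constant, mirroring the fixed plan in the backward pass; under autodiff, gradients would flow only through `C`.

```python
import numpy as np

def sinkhorn_plan(C, reg=0.05, n_iter=1000):
    """Entropic OT plan between uniform marginals (stand-in for an exact solver).

    C: (S, K) cost matrix. Rows of the result sum to 1/S, columns approximately to 1/K.
    """
    S, K = C.shape
    a = np.full(S, 1.0 / S)            # uniform weights on interface points
    b = np.full(K, 1.0 / K)            # uniform weights on keypoints
    G = np.exp(-C / reg)               # Gibbs kernel
    u = np.ones(S)
    for _ in range(n_iter):
        v = b / (G.T @ u)              # match column marginals
        u = a / (G @ v)                # match row marginals
    return u[:, None] * G * v[None, :]

def ot_loss(interface_pts, keypoints, reg=0.05):
    """Return <T, C> and the plan T, with T treated as a constant."""
    # Assumed cost: squared Euclidean distance, shape (S, K).
    C = np.sum((interface_pts[:, None, :] - keypoints[None, :, :]) ** 2, axis=-1)
    scale = max(C.max(), 1e-12)        # rescale so the kernel does not underflow
    T = sinkhorn_plan(C / scale, reg)  # plan computed without gradient tracking
    return float(np.sum(T * C)), T
```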

Results
Experiments. We train all models on the BindingMOAD dataset [18], which contains approximately 40,000 experimentally determined protein/ligand structures from the Protein Data Bank [19]. We train baseline models in which the ligand point cloud is connected directly to the input protein point cloud, without any keypoint representation. Baseline models are trained with all-atom and Cα protein representations. We train keypoint, all-atom, and Cα models with both EGNN and GVP architectures to evaluate the effect of GNN expressivity.
We sample 100 ligands from every pocket in the test set. Generated ligands are subjected to force-field minimization while the binding pocket is held fixed, and we measure the RMSD of the ligand pose before and after minimization. If a ligand is in an unreasonable pose or forms unfavorable interactions with the binding pocket, its RMSD upon minimization will be larger. Additionally, we use the AutoDock Vina scoring function [20] to score the force-field-minimized ligands and use the distribution of scores as a proxy for how well ligands are designed for their target pocket.
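The RMSD metric can be sketched as follows. The sketch assumes that the minimized conformer keeps the original atom ordering and that no rigid superposition is applied before comparison, since the pose relative to the fixed pocket is what matters here; `pose_rmsd` is a hypothetical name.

```python
import numpy as np

def pose_rmsd(before, after):
    """RMSD between two (N, 3) conformers with identical atom ordering, no superposition."""
    diff = np.asarray(before, dtype=float) - np.asarray(after, dtype=float)
    return float(np.sqrt(np.mean(np.sum(diff**2, axis=1))))
```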

Generated Ligand Quality
We evaluate ligand quality using the cumulative distribution functions (CDFs) of the RMSD upon force-field minimization and of the Vina score, shown in Figure 2; for both metrics, higher CDF values indicate higher-quality ligands. Most notably, the GVP keypoint model performs comparably to the all-atom models despite using 10x fewer nodes to represent the binding pocket. Models using Cα binding pocket representations produce lower-quality ligands than those using all-atom pocket representations, which is consistent with prior works [1,3]. The EGNN keypoint model produces ligands of equivalent quality to those of the Cα models.
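Because lower RMSD and lower (more negative) Vina scores are both better, a model whose empirical CDF sits higher at a given threshold has a larger fraction of its samples below that threshold. A minimal sketch of evaluating an empirical CDF on a grid of thresholds (the helper name is hypothetical):

```python
import numpy as np

def empirical_cdf(values, grid):
    """Fraction of samples <= each threshold in `grid`."""
    values = np.sort(np.asarray(values, dtype=float))
    # side="right" counts samples equal to the threshold as <= it.
    return np.searchsorted(values, grid, side="right") / len(values)
```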
Inference performance. We sample 100 molecules per pocket for 10 binding pockets and report the mean wall-clock time per binding pocket. The sampling times in Figure 2 show that keypoint models are 3x faster than their corresponding all-atom models. Additional results in Appendix D.2 show that inference time can be traded off against ligand quality by changing the number of keypoints.

Conclusions
Our receptor encoder module is capable of learning compressed representations of binding pocket structure, which enables a 3x reduction in inference time while maintaining comparable quality of generated ligands. It may therefore serve as a useful tool for scaling inference in molecular diffusion models. Moreover, our work demonstrates that learned structure encoders can provide valuable flexibility to trade off computational demands against model performance.
The GVP keypoint model approached all-atom levels of performance, while the EGNN keypoint model failed to exceed the performance of Cα representations. This result supports our hypothesis that EGNN struggles to learn on molecular representations in which a single node represents multiple atoms, and it may serve as practical guidance for practitioners designing geometric deep learning models for molecular structure.
Models are trained and sampled with T = 1000 diffusion steps and the same polynomial noise schedule used in [1,10].
Force-field minimization is done using the UFF force-field implemented in RDKit [21].
For inference time analysis, inference is run using the same batch size for all models on the same GPU.
The baseline models are conceptually and functionally equivalent to DiffSBDD; however, they are our own re-implementation rather than the original DiffSBDD code. This enables meaningful comparisons that control for factors outside of the protein representation, e.g., minor architectural differences and inference time differences due to code efficiency.

B Architecture
We operate on point clouds, in which every point is a node. Specifically, there are three distinct point clouds in this work: the ligand point cloud $z^{(L)}$, the protein binding pocket point cloud $z^{(P)}$, and the keypoint point cloud produced by the receptor encoder, $z^{(KP)} = E_\theta(z^{(P)})$. These three point clouds are combined into a single heterogeneous graph with three node types: protein nodes, keypoint nodes, and ligand nodes. Edges in the graph are directed, resulting in 3 × 3 = 9 possible unique edge types.
When using the EGNN architecture, each node i in the point cloud is endowed with a position in space $x_i \in \mathbb{R}^3$ and scalar features $s_i \in \mathbb{R}^d$. For all-atom point clouds, $s_i$ is initialized as a one-hot encoding of the atom's element. For Cα point clouds, $s_i$ is initialized as a one-hot encoding of the amino acid identity. When using the GVP-GNN architecture, nodes are additionally endowed with vector features $v_i \in \mathbb{R}^{c \times 3}$, which are initialized to zeros.
A graph convolution is defined as one round of message passing, message aggregation, and node feature updates.A single graph convolution makes up one layer of a GNN.Both the receptor encoder and the noise prediction models stack several graph convolution layers.
Separate message-generating and node-update functions are instantiated for each edge and node type, each with its own set of learnable parameters. Within the receptor encoder, graph convolutions are performed along only one edge type at a time: P → P edges followed by P → KP edges. Within the noise prediction network, graph convolutions are performed along the following edge types simultaneously: KP → L, L → KP, KP → KP, and L → L.
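The edge-type bookkeeping can be sketched as follows: three node types yield nine ordered (source, destination) pairs, and each module uses a subset of them. The encoder's second stage is written as pocket-to-keypoint, following Appendix B.4, where each pocket node is given an outgoing edge to every keypoint. The names `make_edge_params` and `init` are hypothetical; `init` stands in for creating one independent parameter set per edge type.

```python
from itertools import product

NODE_TYPES = ("P", "KP", "L")  # protein, keypoint, ligand

# Every ordered (source, destination) pair of node types is a distinct edge type.
ALL_EDGE_TYPES = [(src, dst) for src, dst in product(NODE_TYPES, repeat=2)]

# Receptor encoder: one edge type per stage, applied sequentially.
ENCODER_STAGES = [[("P", "P")], [("P", "KP")]]

# Noise prediction network: these edge types are used together in every layer.
# Ligand nodes see the pocket only through keypoints (no P -> L edges).
NOISE_NET_EDGES = [("KP", "L"), ("L", "KP"), ("KP", "KP"), ("L", "L")]

def make_edge_params(edge_types, init):
    """One independent parameter set per edge type (init is a hypothetical factory)."""
    return {et: init(et) for et in edge_types}
```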

B.1 EGNN graph convolution
EGNN computes separate messages for scalar and position features, defined by Equations 5 and 6, respectively:

$$m^{(s)}_{ij} = \phi_s\left(s_i : s_j : d_{ij}\right) \qquad (5)$$

$$m^{(x)}_{ij} = \left(x_i - x_j\right)\phi_x\left(s_i : s_j : d_{ij}\right) \qquad (6)$$

where : denotes concatenation and $d_{ij}$ is the Euclidean distance between nodes i and j. The functions $\phi_s : \mathbb{R}^{2d+1} \to \mathbb{R}^d$ and $\phi_x : \mathbb{R}^{2d+1} \to \mathbb{R}$ are implemented as shallow multi-layer perceptrons (MLPs). Incoming messages are aggregated on each node and used to update the scalar features and position of each node:

$$s^{(l+1)}_i = s^{(l)}_i + \frac{1}{C}\sum_{j \in \mathcal{N}(i)} m^{(s)}_{ij} \qquad (7)$$

$$x^{(l+1)}_i = x^{(l)}_i + \frac{1}{C}\sum_{j \in \mathcal{N}(i)} m^{(x)}_{ij} \qquad (8)$$

Note that we are overloading the node feature notation here. Previously, the superscript on a node feature indicated the type of node (ligand, keypoint, protein). In Equations 7 and 8, the superscript indicates the network layer: $s^{(l)}_i$ and $x^{(l)}_i$ are node i's scalar and position features at layer l. C is the average in-degree of the graph, and $\mathcal{N}(i)$ is the set of nodes that have edges pointing to node i.
Layer normalization is applied to scalar features after each graph convolution.
For a noise prediction network containing L EGNN convolutions, the predicted position noise is obtained by subtracting the layer-0 node positions from the layer-L node positions: $\hat{\epsilon}^{(x)}_i = x^{(L)}_i - x^{(0)}_i$. The predicted noise for atom features is obtained by passing the scalar features of each node at layer L through a shallow MLP with a linear output layer:

$$\hat{\epsilon}^{(s)}_i = \mathrm{MLP}\left(s^{(L)}_i\right)$$
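An EGNN convolution and the position-noise readout can be sketched in NumPy as follows. Several details are assumptions rather than confirmed specifics: residual updates normalized by C, the relative-position direction $x_i - x_j$, two-layer SiLU MLPs for $\phi_s$ and $\phi_x$, and a layer norm without learnable affine parameters. The final check in the usage note confirms that the predicted position noise rotates with the input while scalar features stay invariant.

```python
import numpy as np

def mlp(params, h):
    """Two-layer MLP with SiLU hidden activation and linear output (assumed form)."""
    W1, b1, W2, b2 = params
    z = h @ W1 + b1
    z = z / (1.0 + np.exp(-z))                          # SiLU
    return z @ W2 + b2

def layer_norm(s, eps=1e-5):
    """Layer norm over the feature axis, without learnable scale/shift."""
    mu = s.mean(axis=-1, keepdims=True)
    var = s.var(axis=-1, keepdims=True)
    return (s - mu) / np.sqrt(var + eps)

def egnn_conv(s, x, edges, phi_s, phi_x, C):
    """One EGNN convolution over directed edges (j -> i).

    s: (N, d) scalars, x: (N, 3) positions, edges: (E, 2) array of (j, i).
    """
    j, i = edges[:, 0], edges[:, 1]
    rel = x[i] - x[j]                                   # (E, 3) relative positions
    d_ij = np.linalg.norm(rel, axis=-1, keepdims=True)  # (E, 1) invariant distances
    h = np.concatenate([s[i], s[j], d_ij], axis=-1)     # (E, 2d+1)
    m_s = mlp(phi_s, h)                                 # (E, d) scalar messages
    m_x = rel * mlp(phi_x, h)                           # (E, 3) equivariant messages
    agg_s = np.zeros_like(s)
    agg_x = np.zeros_like(x)
    np.add.at(agg_s, i, m_s)                            # sum messages per destination
    np.add.at(agg_x, i, m_x)
    return layer_norm(s + agg_s / C), x + agg_x / C

def predict_position_noise(s, x, edges, layers, C):
    """Stack EGNN convolutions; position noise is x^(L) - x^(0)."""
    x0 = x
    for phi_s, phi_x in layers:
        s, x = egnn_conv(s, x, edges, phi_s, phi_x, C)
    return s, x - x0
```

Because the position update only ever adds multiples of relative vectors, $x^{(L)} - x^{(0)}$ is translation-invariant and rotates with the input, which is the behavior required of $\hat{\epsilon}^{(x)}$.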

B.2 GVP-GNN graph convolution
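Our implementation of GVP-based graph convolutions does not update node positions; rather, it updates scalar and vector features. GVPs accept and return a tuple of scalar and vector features. Therefore, scalar and vector messages $m^{(s)}_{i \to j}$ and $m^{(v)}_{i \to j}$ are generated by a single function $g_m$, which is two GVPs chained together. The input to $g_m$ is formed by concatenating (denoted :) the features of nodes i and j with $\mathrm{rbf}(d_{ij})$, a radial basis function (RBF) embedding of the distance between nodes i and j.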
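The RBF embedding can be sketched as a bank of Gaussians with evenly spaced centers. The range (0 to 20 Å), the number of bins, and the width rule are assumptions for illustration; the text does not specify them.

```python
import numpy as np

def rbf(d, d_min=0.0, d_max=20.0, n_bins=16):
    """Gaussian RBF embedding of distances: shape (...,) -> (..., n_bins)."""
    centers = np.linspace(d_min, d_max, n_bins)     # evenly spaced centers (assumed)
    width = (d_max - d_min) / n_bins                # shared Gaussian width (assumed)
    d = np.asarray(d, dtype=float)[..., None]
    return np.exp(-(((d - centers) / width) ** 2))
```

Expanding a single distance into many smooth bins gives the message MLPs an easier-to-use, still rotation-invariant, description of the edge length.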
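The node-update function is defined exactly as described in [14]:

$$(s_i, v_i) \leftarrow \mathrm{LN}\left((s_i, v_i) + \frac{1}{|\mathcal{N}(i)|}\,\mathrm{DO}\Big(\sum_{j \in \mathcal{N}(i)} m_{j \to i}\Big)\right)$$

$$(s_i, v_i) \leftarrow \mathrm{LN}\left((s_i, v_i) + \mathrm{DO}\big(g_u(s_i, v_i)\big)\right)$$

where LN and DO are layer norm and dropout, respectively. The node-update function $g_u$ is a chain of three GVPs.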

B.3 GVP noise-prediction block
After the GVP graph convolutions, every ligand node holds scalar and vector features, $s_i$ and $v_i$, which contain information about the ligand atom's environment. These features are passed into a "noise-prediction block" that returns the noise predictions $\hat{\epsilon}^{(s)}$ and $\hat{\epsilon}^{(x)}$.
Within the noise-prediction block, $s_i$ and $v_i$ are first passed through a chain of 4 GVPs. The first three GVPs output the same number of scalar and vector features as they receive. The final GVP outputs the same number of scalar features as its input but only one vector feature, and its vector-gating activation function, typically a sigmoid, is replaced with the identity. The output vector feature is $\hat{\epsilon}^{(x)}$. The output scalar features are passed through an additional MLP to produce a vector with the same length as the one-hot atom-type encoding; this output becomes $\hat{\epsilon}^{(s)}$.
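A single GVP with vector gating can be sketched as below, in the spirit of [13,14] but not as their exact implementation. The gate is computed here from the activated output scalars, and the `gate_act` argument makes it possible to swap the usual sigmoid for the identity, as in the final GVP of the noise-prediction block. The parameter tuple layout, the tanh scalar activation, and all names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gvp(s, V, params, scalar_act=np.tanh, gate_act=sigmoid):
    """Geometric Vector Perceptron with vector gating, for a single node.

    s: (n_s,) scalars, V: (n_v, 3) vectors.
    params = (Wh, Wmu, Ws, bs, Wg, bg), a hypothetical layout.
    """
    Wh, Wmu, Ws, bs, Wg, bg = params
    Vh = Wh @ V                                  # (h, 3): linear mix of vector channels
    Vmu = Wmu @ Vh                               # (n_v_out, 3): output vector channels
    sh = np.linalg.norm(Vh, axis=-1)             # (h,): rotation-invariant norms
    s_out = Ws @ np.concatenate([s, sh]) + bs    # scalars see vector norms
    s_act = scalar_act(s_out) if scalar_act else s_out
    gate = gate_act(Wg @ s_act + bg)             # one gate per output vector
    return s_act, gate[:, None] * Vmu            # gating by a scalar keeps equivariance
```

With `gate_act=lambda z: z`, the gate can take any sign and magnitude rather than being confined to (0, 1). One plausible motivation is that the predicted position noise should not be limited in scale by a saturating gate.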

B.4 Keypoint placement via graph attention
After several graph convolutions along P → P edges, the nodes of $z^{(P)}$ possess features describing their local atomic environments, which are used to determine the positions of keypoints. K keypoints are added to the graph, initially with no positions, scalar features, or vector features. Each node in $z^{(P)}$ is given an outgoing edge to every keypoint; thus, the nodes of $z^{(P)}$ and $z^{(KP)}$ form a unidirectional complete bipartite graph.
Keypoint positions are then computed as a weighted sum of pocket node positions.
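One way to realize this weighted sum is sketched below. Because keypoints start without features, the sketch assumes learned per-keypoint query vectors that attend over key projections of the pocket-node features; `queries`, `W_k`, and `place_keypoints` are hypothetical, and the paper's dot-product attention variant may differ in its details.

```python
import numpy as np

def place_keypoints(pocket_feats, pocket_pos, queries, W_k):
    """Keypoint positions as attention-weighted averages of pocket positions.

    pocket_feats: (N, d) invariant features, pocket_pos: (N, 3),
    queries: (K, d_k) hypothetical learned per-keypoint queries,
    W_k: (d, d_k) hypothetical key projection.
    """
    keys = pocket_feats @ W_k                             # (N, d_k)
    logits = queries @ keys.T / np.sqrt(keys.shape[1])    # (K, N) scaled dot products
    logits -= logits.max(axis=1, keepdims=True)           # numerical stability
    attn = np.exp(logits)
    attn /= attn.sum(axis=1, keepdims=True)               # softmax over pocket nodes
    return attn @ pocket_pos                              # (K, 3) convex combinations
```

Because the attention weights depend only on invariant features and each row sums to one, the keypoints rotate and translate together with the pocket and always lie inside its convex hull.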

D.3 Additional ligand quality metrics
Prior works on de novo ligand generation have reported QED (drug-likeness), SA (synthetic accessibility), and diversity metrics for generated ligands. Although we maintain that ligand-only metrics are less critical for evaluating generative models than metrics describing ligand/pocket interactions, we nevertheless provide these metrics for ligands generated by our model in Table 2. Diversity is the average of one minus the Tanimoto similarity over all pairs of ligands generated for a given pocket.
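The diversity metric can be sketched as follows. To stay dependency-free, the sketch assumes that fingerprints are given as Python sets of on-bit indices (in practice they would come from a cheminformatics toolkit). It also treats two empty fingerprints as identical, which is a convention chosen here. At least two ligands per pocket are required.

```python
from itertools import combinations

def tanimoto(a, b):
    """Tanimoto similarity of two fingerprints given as sets of on-bit indices."""
    union = len(a | b)
    return len(a & b) / union if union else 1.0   # two empty fingerprints count as identical

def pocket_diversity(fingerprints):
    """Mean (1 - Tanimoto) over all unordered pairs of ligands for one pocket."""
    pairs = list(combinations(fingerprints, 2))
    return sum(1.0 - tanimoto(a, b) for a, b in pairs) / len(pairs)
```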

Figure 1: Message passing is performed between receptor nodes. Learned receptor embeddings are used to place keypoints inside the binding pocket. Keypoints extract local features of the binding pocket and are then used to condition the ligand generation process.

Figure 2: Left, Middle: CDFs of ligand RMSD from force-field minimization and of Vina score. Right: Sampling time per molecule averaged over the same ten binding pockets for each model.

Figure 3: Left, Middle: CDFs of ligand RMSD from force-field minimization and of Vina score. Right: Sampling time per molecule averaged over the same ten binding pockets for each model.